package org.test4j.module.spring;

import org.junit.jupiter.api.extension.ExtensionContext;
import org.test4j.Context;

import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import static org.test4j.mock.faking.util.ReflectUtility.doThrow;

/**
 * Spring environment settings for test4j.
 *
 * <p>Keeps a per-test-class record of whether the class is a Spring test, i.e. whether it
 * (or one of its superclasses) carries {@code @SpringBootTest} or {@code @ContextConfiguration}.
 * Those annotations are looked up by name, so Spring does not need to be on the classpath.
 * Bean lookup, injection and context initialization are delegated to {@code SpringInit}.
 *
 * @author darui.wu
 */
public class SpringEnv {
    /**
     * Test class -> "is a Spring test" flag, filled by {@link #setSpringEnv(Class)}.
     * A ConcurrentHashMap rejects null keys, hence the null guards in the accessors below.
     */
    private static final Map<Class<?>, Boolean> classIsSpring = new ConcurrentHashMap<>();

    /** Fully-qualified names of the annotations that mark a class as a Spring test. */
    private static final String[] springTestAnnotations = {
        "org.springframework.boot.test.context.SpringBootTest",
        "org.springframework.test.context.ContextConfiguration"
    };

    /**
     * Tells whether the current test class ({@link Context#currTestClass()}) was registered
     * as a Spring test.
     *
     * @return true only if that class was registered via {@link #setSpringEnv(Class)} and
     *     carries one of the Spring test annotations
     */
    public static boolean isSpringEnv() {
        return isSpringEnv(Context.currTestClass());
    }

    /**
     * Registers {@code clazz}, recording whether it carries one of the Spring test annotations.
     *
     * @param clazz the test class; {@code null} is ignored (nothing to register, and the
     *     backing ConcurrentHashMap would reject a null key)
     */
    public static void setSpringEnv(Class<?> clazz) {
        if (clazz == null) {
            return;
        }
        classIsSpring.put(clazz, isSpringTest(clazz));
    }

    /**
     * Tells whether {@code clazz} was registered as a Spring test.
     *
     * @param clazz the test class, may be {@code null}
     * @return false for {@code null}, for unregistered classes, and for non-Spring classes
     */
    public static boolean isSpringEnv(Class<?> clazz) {
        // Guard null first: ConcurrentHashMap.get(null) throws NullPointerException.
        return clazz != null && Boolean.TRUE.equals(classIsSpring.get(clazz));
    }

    /**
     * Checks whether {@code aClass} (or a superclass) carries any of {@link #springTestAnnotations}.
     */
    private static boolean isSpringTest(Class<?> aClass) {
        if (aClass == null) {
            return false;
        }
        for (String annotation : springTestAnnotations) {
            if (hasAnnotation(aClass, annotation)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Annotation class name -> loaded class, or {@code null} when the class could not be found.
     * Inside this class it is only read and written from the synchronized
     * {@link #getAnnotationClass(String)}, because HashMap is not thread-safe.
     */
    static Map<String, Class> HasAnnotation = new HashMap<>(5);

    /**
     * Checks whether {@code objectClass} (or a superclass) carries the annotation named
     * {@code annotation}.
     *
     * @return false if that annotation type cannot be loaded or is not an annotation type
     */
    private static boolean hasAnnotation(Class<?> objectClass, String annotation) {
        Class<?> type = getAnnotationClass(annotation);
        if (type == null || !type.isAnnotation()) {
            return false;
        }
        return getClassLevelAnnotation(type.asSubclass(Annotation.class), objectClass) != null;
    }

    /**
     * Loads the class named {@code annotation}, caching both hits and misses so each name
     * goes through {@link Class#forName(String)} at most once.
     *
     * @return the loaded class, or {@code null} if it is not on the classpath
     */
    private static synchronized Class<?> getAnnotationClass(String annotation) {
        if (!HasAnnotation.containsKey(annotation)) {
            Class<?> loaded;
            try {
                loaded = Class.forName(annotation);
            } catch (ClassNotFoundException e) {
                // Not on the classpath (e.g. Spring test support absent): remember the miss.
                loaded = null;
            }
            HasAnnotation.put(annotation, loaded);
        }
        return HasAnnotation.get(annotation);
    }

    /**
     * Looks for {@code annotationClass} on {@code clazz} and then on each superclass, stopping
     * before {@link Object} or when there is no superclass (interfaces, primitives).
     *
     * @return the first annotation found, or {@code null}
     */
    private static <T extends Annotation> T getClassLevelAnnotation(
            Class<T> annotationClass, Class<?> clazz) {
        for (Class<?> current = clazz;
                current != null && current != Object.class;
                current = current.getSuperclass()) {
            // Inspect the class currently being walked, not the starting class.
            T annotation = current.getAnnotation(annotationClass);
            if (annotation != null) {
                return annotation;
            }
        }
        return null;
    }

    /**
     * Returns the spring bean named {@code beanName} from the spring container of the current
     * test class (delegates to {@code SpringInit.getBeanByName}).
     *
     * @param beanName the bean name
     * @return the bean, as returned by {@code SpringInit}
     */
    public static <T> T getBeanByName(String beanName) {
        return SpringInit.getBeanByName(beanName);
    }

    /**
     * Returns a spring bean of the given type (delegates to {@code SpringInit.getBeanByType}).
     *
     * @param beanType the bean type
     * @return the bean, as returned by {@code SpringInit}
     */
    public static <T> T getBeanByType(Class beanType) {
        return SpringInit.getBeanByType(beanType);
    }

    /**
     * Injects spring beans into {@code testedObject} via {@code SpringInit.injectSpringBeans},
     * but only when the current test class is registered as a Spring test; otherwise a no-op.
     *
     * @param testedObject the object to inject into
     */
    public static void injectSpringBeans(Object testedObject) {
        if (!SpringEnv.isSpringEnv()) {
            return;
        }
        SpringInit.injectSpringBeans(testedObject);
    }

    /**
     * Intended for work that must run before test4j initialization, for example mocking that
     * must be in place before the spring context loads.
     *
     * <p>Invokes, on {@code test}, every zero-argument method declared directly by its class
     * that is annotated with {@link BeforeSpringContext}. Superclass methods are not scanned,
     * and the invocation order follows {@link Class#getDeclaredMethods()}, which is unspecified.
     * Any exception from the reflective call is handed to {@code doThrow}.
     *
     * @param test the test class instance
     */
    public static void invokeSpringInitMethod(Object test) {
        Method[] methods = test.getClass().getDeclaredMethods();
        for (Method method : methods) {
            if (method.getParameterCount() == 0 && method.getAnnotation(BeforeSpringContext.class) != null) {
                method.setAccessible(true);
                try {
                    method.invoke(test);
                } catch (Exception e) {
                    doThrow(e);
                }
            }
        }
    }

    /**
     * Runs around the startup of the test spring container (delegates to
     * {@code SpringInit.doSpringInitial}). Per the original design notes, this:
     * 1. runs the {@code @BeforeSpringContext} methods,
     * 2. initializes injection into the test instance,
     * 3. registers the spring container.
     *
     * @param testInstance the test instance
     * @param context the JUnit extension context
     * @throws Exception whatever {@code SpringInit.doSpringInitial} throws
     */
    public static void doSpringInitial(Object testInstance, ExtensionContext context) throws Exception {
        SpringInit.doSpringInitial(testInstance, context);
    }
}